package cn.jiangzeyin.controller.classifiers;

import cn.hutool.core.convert.Convert;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import cn.jiangzeyin.CommonProperties;
import cn.jiangzeyin.common.JsonMessage;
import com.alibaba.fastjson.JSONArray;
import com.hankcs.hanlp.classification.classifiers.IClassifier;
import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;
import com.hankcs.hanlp.classification.models.NaiveBayesModel;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.utility.GlobalObjectPool;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import java.io.File;
import java.nio.file.InvalidPathException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * REST endpoint that classifies text with a previously trained naive Bayes model file.
 * <p>
 * Deserialized models are cached in HanLP's {@link GlobalObjectPool} under
 * {@code CommonProperties.CACHE_PREFIX + path}. A model is re-read from disk whenever the
 * file's last-modified timestamp differs from the one recorded when it was last loaded.
 * <p>
 * Created by jiangzeyin on 2017/12/9.
 */
@RestController
@RequestMapping("classifiers")
public class ClassificationPredict extends BaseClassification {
    /**
     * Model directory, resolved lazily on the first request through {@code getModelFolder()}.
     * Declared {@code volatile} so a value set by one request thread is visible to the others.
     */
    private static volatile String modelFolder;
    /**
     * Last-modified timestamp of each model file at the moment it was loaded, keyed by the
     * normalized model path. The controller serves concurrent requests, so the map must be
     * safe for concurrent access.
     */
    private static final Map<String, Long> MODEL_FILE_LAST_MODIFIED = new ConcurrentHashMap<>();

    /**
     * Classifies {@code text} and returns the highest-scoring categories.
     *
     * @param modelName model file to use, relative to the model folder
     * @param text      text to classify
     * @param size      how many of the highest-scoring categories to return; defaults to 5 when
     *                  absent or not convertible to an integer, and must be {@code >= 1}
     * @return a JSON message. The codes are:
     *         <ul>
     *         <li>400 for invalid arguments;</li>
     *         <li>404 when the model file does not exist;</li>
     *         <li>500 when the file cannot be loaded as a {@link NaiveBayesModel};</li>
     *         <li>200 with the top {@code size} category/score entries in descending score order.</li>
     *         </ul>
     *         The error message from {@code getModelFolder()} is returned as-is when the model
     *         folder cannot be resolved.
     */
    @RequestMapping(value = "predict.json", method = RequestMethod.POST, produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
    public String train(String modelName, String text, String size) {
        if (StrUtil.isEmpty(modelName)) {
            return JsonMessage.getString(400, "modelName is null");
        }
        if (StrUtil.isEmpty(text)) {
            return JsonMessage.getString(400, "text is null");
        }
        int sizeInt = Convert.toInt(size, 5);
        if (sizeInt < 1) {
            return JsonMessage.getString(400, "size must >=1");
        }
        modelName = replacePath(modelName);
        // Resolve the model folder once; read the volatile field into a local so this request
        // uses a single consistent value.
        String folder = modelFolder;
        if (StrUtil.isEmpty(folder)) {
            String[] result = getModelFolder();
            if (result[0] != null) {
                return result[0];
            }
            folder = result[1];
            modelFolder = folder;
        }
        String path = FileUtil.normalize(folder + "/" + modelName);
        // modelName comes from the request: refuse anything that escapes the model folder
        // (defense in depth on top of replacePath).
        if (!isInsideFolder(folder, path)) {
            return JsonMessage.getString(400, "modelName is invalid");
        }
        File modelFile = new File(path);
        if (!modelFile.exists() || !modelFile.isFile()) {
            return JsonMessage.getString(404, modelName + " not exists or not isFile");
        }
        NaiveBayesModel model = loadModel(path, modelFile.lastModified());
        if (model == null) {
            return JsonMessage.getString(500, modelName + " is not a valid NaiveBayesModel");
        }
        IClassifier classifier = new NaiveBayesClassifier(model);
        Map<String, Double> scores = classifier.predict(text);
        return JsonMessage.getString(200, "success", topEntries(scores, sizeInt));
    }

    /**
     * Returns whether {@code path} lies strictly inside {@code folder}.
     * <p>
     * Both values are normalized the same way before the component-wise comparison.
     * Unparseable paths are treated as outside the folder.
     */
    private static boolean isInsideFolder(String folder, String path) {
        try {
            Path base = Paths.get(FileUtil.normalize(folder)).toAbsolutePath().normalize();
            Path target = Paths.get(path).toAbsolutePath().normalize();
            return target.startsWith(base) && !target.equals(base);
        } catch (InvalidPathException e) {
            return false;
        }
    }

    /**
     * Returns the cached model for {@code path}, reloading it from disk when it is not cached
     * or when {@code lastModified} differs from the recorded timestamp.
     *
     * @return the model, or {@code null} when the file could not be read as a NaiveBayesModel
     */
    private static NaiveBayesModel loadModel(String path, long lastModified) {
        String cacheKey = CommonProperties.CACHE_PREFIX + path;
        NaiveBayesModel cached = GlobalObjectPool.get(cacheKey);
        long loadedAt = MODEL_FILE_LAST_MODIFIED.getOrDefault(path, 0L);
        if (cached != null && lastModified == loadedAt) {
            return cached;
        }
        // NOTE(review): IOUtil.readObjectFrom performs Java-native deserialization, so the model
        // folder must only contain trusted files.
        Object loaded = IOUtil.readObjectFrom(path);
        // readObjectFrom may yield null on a read failure; never cache or use such a result.
        if (!(loaded instanceof NaiveBayesModel)) {
            return null;
        }
        NaiveBayesModel model = (NaiveBayesModel) loaded;
        GlobalObjectPool.put(cacheKey, model);
        MODEL_FILE_LAST_MODIFIED.put(path, lastModified);
        return model;
    }

    /**
     * Returns at most {@code limit} entries of {@code scores}, ordered by descending score.
     */
    private static JSONArray topEntries(Map<String, Double> scores, int limit) {
        List<Map.Entry<String, Double>> entries = new ArrayList<>(scores.entrySet());
        entries.sort(Map.Entry.<String, Double>comparingByValue(Comparator.reverseOrder()));
        JSONArray top = new JSONArray();
        top.addAll(entries.subList(0, Math.min(limit, entries.size())));
        return top;
    }
}
